# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Use Generative AI to supplement Airflow DAGs with doc_md values.
"""

import os
import logging
import argparse

import ast

from vertexai.generative_models import (
    GenerativeModel,)

logging.basicConfig(level=logging.INFO)
BASE_PROMPT = "You're a Senior Data Engineer and an Apache Airflow expert."


def generate_tags(description: str):
    """
    Use Google Gemini Pro model to generate an Airflow tags list.

    Args:
        description: Natural-language description of the DAG, used as the
            model input.

    Returns:
        A list of tag strings parsed from the model response, or an empty
        list when the model returns no candidates or unparsable output.
        (An empty list — not ``""`` — is returned on failure so the value
        is always safe to use as an Airflow ``tags`` argument.)
    """

    model = GenerativeModel("gemini-1.0-pro")

    response = model.generate_content([
        f"{BASE_PROMPT} I will give you an Airflow DAG Description. Your task is to create a list of tags that categorize the DAG. The list must have at least 3 tags and you must respond in the following format: ['','','']",  # pylint: disable=line-too-long
        description,
    ])
    if len(response.candidates[0].content.parts) > 0:
        text = response.text
        # The model may wrap the list in extra prose; extract only the
        # bracketed portion before parsing.
        list_start_index = text.find("[")
        list_end_index = text.rfind("]")
        if list_start_index == -1 or list_end_index < list_start_index:
            logging.warning("No tag list found in model response: %s", text)
            return []
        try:
            # Bug fix: previously parsed the raw response text rather than
            # the extracted list slice, which fails whenever the model adds
            # any text around the list.
            tags = ast.literal_eval(text[list_start_index:list_end_index + 1])
        except (ValueError, SyntaxError):
            logging.warning("Could not parse tags from response: %s", text)
            return []

        logging.info(str(tags))
        return tags
    return []


def generate_desc(file_contents: str) -> str:
    """
    Use Google Gemini Pro model to generate a one-sentence DAG description.

    Args:
        file_contents: The full source text of an Airflow DAG .py file.

    Returns:
        The model-generated single-sentence description, or an empty string
        when the response contains no content parts.
    """

    model = GenerativeModel("gemini-1.0-pro")

    response = model.generate_content([
        f"{BASE_PROMPT} Generate a single sentence description for this Airflow DAG:",
        file_contents,
    ])

    # A candidate with zero content parts means the model produced no usable text.
    if len(response.candidates[0].content.parts) > 0:
        logging.info(response.text)
        return response.text

    return ""


def generate_docs(file_contents: str):
    """
    Use Google Gemini Pro model to generate documentation in markdown.

    Args:
        file_contents: The full source text of an Airflow DAG .py file.

    Returns:
        A markdown document describing the DAG with an attribution footer,
        or an empty string when the model returns no content parts.
    """

    model = GenerativeModel("gemini-1.0-pro")

    response = model.generate_content([
        f"{BASE_PROMPT} Create a detailed markdown description of the following Airflow DAG and what it does:",  # pylint: disable=line-too-long
        file_contents,
    ])

    # Guard clause: nothing usable came back from the model.
    if len(response.candidates[0].content.parts) == 0:
        return ""

    template = """
{doc}

---

*Generated by* 

<img src="https://upload.wikimedia.org/wikipedia/commons/8/8a/Google_Gemini_logo.svg" alt="Google Gemini" width="85"/>
    """
    rendered = template.format(doc=response.text)
    # Strip curly braces so the resulting doc_md text cannot be mistaken
    # for a format/Jinja template downstream.
    rendered = rendered.replace("{", "")
    return rendered.replace("}", "")


def generate_dag_metadata(src_dir, tgt_dir):  #pylint: disable=too-many-locals
    """
    Loop through a DAG folder. Supplement DAGs with Generative AI for doc_md, tags, and description
    attributes. Will only supplement if DAG does not have existing values. Write supplemented
    DAG to new .py files in a target directory.

    Args:
        src_dir: Directory containing the existing Airflow DAG .py files.
        tgt_dir: Directory where the supplemented DAG files are written
            (same base filenames as the sources).
    """

    for filename in os.listdir(src_dir):  # pylint: disable=too-many-nested-blocks
        # Bug fix: was `".py" in filename`, which also matched .pyc and any
        # name merely containing ".py".
        if not filename.endswith(".py"):
            continue

        # Bug fix: paths were built with string concatenation, which breaks
        # when src_dir/tgt_dir lack a trailing separator.
        file_path = os.path.join(src_dir, filename)
        logging.info("Generating metadata for: %s", file_path)

        # Read once, with an explicit encoding, and close the handle
        # (the file was previously opened twice and never closed).
        with open(file_path, encoding="utf-8") as src_file:
            file_contents = src_file.read()

        docs = generate_docs(file_contents)
        desc = generate_desc(file_contents)
        # Tags are derived from the generated description, not the raw source.
        tags = generate_tags(desc)

        tree = ast.parse(file_contents)

        for node in ast.walk(tree):
            if isinstance(node, ast.Call):
                # Match both `DAG(...)` and `airflow.DAG(...)` call forms.
                if (isinstance(node.func, ast.Attribute) and
                        node.func.attr == "DAG") or (isinstance(
                            node.func, ast.Name) and node.func.id == "DAG"):
                    doc_found = False
                    tag_found = False
                    desc_found = False

                    for keyword in node.keywords:
                        if keyword.arg == "doc_md":
                            doc_found = True
                        if keyword.arg == "tags":
                            tag_found = True
                        if keyword.arg == "description":
                            desc_found = True
                    # Only supplement attributes the DAG does not already set.
                    if not desc_found:
                        new_keyword = ast.keyword(
                            arg="description",
                            value=ast.Constant(value=desc))
                        node.keywords.append(new_keyword)
                    if not tag_found:
                        new_keyword = ast.keyword(
                            arg="tags", value=ast.Constant(value=tags))
                        node.keywords.append(new_keyword)
                    if not doc_found:
                        new_keyword = ast.keyword(
                            arg="doc_md", value=ast.Constant(value=docs))
                        node.keywords.append(new_keyword)

        # Bug fix: unparse + write were inside the node loop, rewriting the
        # output file once per AST node. Do it once per source file instead.
        modified_code = ast.unparse(tree)
        with open(os.path.join(tgt_dir, filename), "w",
                  encoding="utf-8") as out_file:
            out_file.write(modified_code)


if __name__ == "__main__":
    # CLI entry point: read the source/target directories and run the
    # metadata generation pass over every DAG file found.
    arg_parser = argparse.ArgumentParser(
        description="Supplement DAGs with doc_md values.")
    arg_parser.add_argument(
        "--src",
        type=str,
        help="The source dir containing existing airflow DAGs")
    arg_parser.add_argument(
        "--tgt",
        type=str,
        help="The target dir for newly documented airflow DAGs")

    cli_args = arg_parser.parse_args()
    generate_dag_metadata(cli_args.src, cli_args.tgt)
